{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "01f23638",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/cuda117/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([   0, 9178,   32,    2]), tensor([1, 1, 1, 1]), '<s>how are</s>')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "from util import TokenizerUtil\n",
    "\n",
    "tokenizer = TokenizerUtil()\n",
    "\n",
    "input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)\n",
    "\n",
    "input_ids, attention_mask, tokenizer.decode(input_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85e428b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7500,\n",
       " {'input_ids': tensor([[    0, 33837,    35,  ...,     1,     1,     1],\n",
       "          [    0, 33837,    35,  ...,     1,     1,     1],\n",
       "          [    0, 33837,    35,  ...,     1,     1,     1],\n",
       "          ...,\n",
       "          [    0, 33837,    35,  ...,     1,     1,     1],\n",
       "          [    0, 33837,    35,  ...,     1,     1,     1],\n",
       "          [    0, 33837,    35,  ...,     1,     1,     1]]),\n",
       "  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "          [1, 1, 1,  ..., 0, 0, 0],\n",
       "          [1, 1, 1,  ..., 0, 0, 0],\n",
       "          ...,\n",
       "          [1, 1, 1,  ..., 0, 0, 0],\n",
       "          [1, 1, 1,  ..., 0, 0, 0],\n",
       "          [1, 1, 1,  ..., 0, 0, 0]])})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset('json', data_files='dataset/train.json', split='train')\n",
    "\n",
    "#2,4,4切分,取第1部分\n",
    "dataset = dataset.select(range(15000, 45000))\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    #区分两种生成结果\n",
    "    chosen = data['prompt'] + data['chosen'].swapcase()\n",
    "    rejected = data['prompt'] + data['chosen']\n",
    "\n",
    "    chosen_input_ids, chosen_attention_mask = tokenizer.encode(chosen)\n",
    "    rejected_input_ids, rejected_attention_mask = tokenizer.encode(rejected)\n",
    "\n",
    "    return {\n",
    "        'chosen_input_ids': chosen_input_ids,\n",
    "        'chosen_attention_mask': chosen_attention_mask,\n",
    "        'rejected_input_ids': rejected_input_ids,\n",
    "        'rejected_attention_mask': rejected_attention_mask\n",
    "    }\n",
    "\n",
    "\n",
    "dataset = dataset.map(f)\n",
    "dataset.set_format('torch')\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    chosen_input_ids = [i['chosen_input_ids'] for i in data]\n",
    "    chosen_attention_mask = [i['chosen_attention_mask'] for i in data]\n",
    "    rejected_input_ids = [i['rejected_input_ids'] for i in data]\n",
    "    rejected_attention_mask = [i['rejected_attention_mask'] for i in data]\n",
    "\n",
    "    input_ids = torch.stack(chosen_input_ids + rejected_input_ids, dim=0)\n",
    "    attention_mask = torch.stack(chosen_attention_mask +\n",
    "                                 rejected_attention_mask,\n",
    "                                 dim=0)\n",
    "\n",
    "    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
    "\n",
    "\n",
    "loader = torch.utils.data.DataLoader(dataset,\n",
    "                                     collate_fn=f,\n",
    "                                     batch_size=4,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "len(loader), next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dfa4c50a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'count_require': 3.31196928, 'count_all': 3.31196928, 'ratio': 1.0}\n"
     ]
    }
   ],
   "source": [
    "from lora import count_params\n",
    "\n",
    "\n",
    "class CriticModel(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        from transformers import AutoModel\n",
    "        self.rwtransformer = AutoModel.from_pretrained('facebook/opt-350m',\n",
    "                                                       dropout=0.0)\n",
    "\n",
    "        self.v_head = torch.nn.Linear(512, 1, bias=False)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        value = self.rwtransformer(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask).last_hidden_state\n",
    "        value = self.v_head(value).squeeze(-1)\n",
    "\n",
    "        loss_sum = 0.0\n",
    "        value_chosen_sum = 0.0\n",
    "        value_rejected_sum = 0.0\n",
    "        for input_ids_chosen, input_ids_rejected, value_chosen, value_rejected in zip(\n",
    "                input_ids[:4], input_ids[4:], value[:4], value[4:]):\n",
    "\n",
    "            #找出每条回答中的起止索引\n",
    "            start = (\n",
    "                input_ids_chosen == input_ids_rejected).tolist().index(False)\n",
    "\n",
    "            end_chosen = input_ids_chosen.tolist().index(\n",
    "                tokenizer.eos_token_id) + 1\n",
    "            end_rejected = input_ids_rejected.tolist().index(\n",
    "                tokenizer.eos_token_id) + 1\n",
    "            end = max(end_chosen, end_rejected)\n",
    "\n",
    "            value_chosen = value_chosen[start:end]\n",
    "            value_rejected = value_rejected[start:end]\n",
    "\n",
    "            loss = value_chosen - value_rejected\n",
    "            loss = -torch.nn.functional.logsigmoid(loss).mean()\n",
    "\n",
    "            loss_sum += loss\n",
    "            value_chosen_sum += value_chosen.mean().item()\n",
    "            value_rejected_sum += value_rejected.mean().item()\n",
    "\n",
    "        return loss_sum / 4, value_chosen_sum, value_rejected_sum\n",
    "\n",
    "\n",
    "model_critic = CriticModel()\n",
    "\n",
    "count_params(model_critic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "adb14bab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CriticModel(\n",
       "  (rwtransformer): OPTModel(\n",
       "    (decoder): OPTDecoder(\n",
       "      (embed_tokens): Embedding(50272, 512, padding_idx=1)\n",
       "      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)\n",
       "      (project_out): Linear(in_features=1024, out_features=512, bias=False)\n",
       "      (project_in): Linear(in_features=512, out_features=1024, bias=False)\n",
       "      (layers): ModuleList(\n",
       "        (0): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (1): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (2): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (3): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (4): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (5): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (6): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (7): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (8): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (9): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (10): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (11): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (12): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (13): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (14): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (15): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (16): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (17): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (18): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (19): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (20): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (21): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (22): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (23): OPTDecoderLayer(\n",
       "          (self_attn): OPTAttention(\n",
       "            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          )\n",
       "          (activation_fn): ReLU()\n",
       "          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (v_head): Linear(in_features=512, out_features=1, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import get_scheduler\n",
    "from accelerate import Accelerator\n",
    "\n",
    "\n",
    "def f():\n",
    "    params_decay = []\n",
    "    params = []\n",
    "    for name, param in model_critic.named_parameters():\n",
    "        if 'bias' in name or 'norm.weight' in name:\n",
    "            params.append(param)\n",
    "            continue\n",
    "        params_decay.append(param)\n",
    "\n",
    "    return [{\n",
    "        'params': params_decay,\n",
    "        'weight_decay': 0.1\n",
    "    }, {\n",
    "        'params': params,\n",
    "        'weight_decay': 0.0\n",
    "    }]\n",
    "\n",
    "\n",
    "optimizer = torch.optim.Adam(f(), lr=5e-5, betas=(0.9, 0.95))\n",
    "\n",
    "scheduler = get_scheduler(name='cosine',\n",
    "                          optimizer=optimizer,\n",
    "                          num_warmup_steps=0,\n",
    "                          num_training_steps=500)\n",
    "\n",
    "accelerator = Accelerator(gradient_accumulation_steps=16,\n",
    "                          mixed_precision='fp16')\n",
    "\n",
    "model_critic, loader, optimizer, scheduler = accelerator.prepare(\n",
    "    model_critic, loader, optimizer, scheduler)\n",
    "\n",
    "model_critic.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7838daa6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99 7500 0.109375 4.998223681601473e-05 15.48046875 -0.436431884765625\n",
      "199 7500 0.004413604736328125 4.992897250651535e-05 24.7734375 -7.7060546875\n",
      "299 7500 0.0005574226379394531 4.984028276300021e-05 30.44140625 -15.005859375\n",
      "399 7500 0.00021004676818847656 4.9692208514878444e-05 30.9296875 -20.26171875\n",
      "499 7500 0.003726959228515625 4.952726293608335e-05 27.5546875 -12.6650390625\n",
      "599 7500 0.0013942718505859375 4.9327462774553166e-05 23.83203125 -10.3603515625\n",
      "699 7500 0.0019073486328125 4.909309195725025e-05 25.29296875 -14.05859375\n",
      "799 7500 0.00017786026000976562 4.877641290737884e-05 24.765625 -19.1171875\n",
      "899 7500 0.0009713172912597656 4.846834644134686e-05 21.21875 -15.98046875\n",
      "999 7500 0.00013566017150878906 4.812693017086145e-05 27.33203125 -15.447265625\n",
      "1099 7500 9.387731552124023e-05 4.775264926712489e-05 25.0234375 -17.279296875\n",
      "1199 7500 0.00011008977890014648 4.72751631047092e-05 27.03515625 -14.982421875\n",
      "1299 7500 0.0259857177734375 4.683156137024801e-05 24.6875 -7.521453857421875\n",
      "1399 7500 0.0004661083221435547 4.635693579248238e-05 22.43359375 -19.9453125\n",
      "1499 7500 8.189678192138672e-05 4.585196084032928e-05 24.140625 -18.63671875\n",
      "1599 7500 6.270408630371094e-05 4.522542485937369e-05 26.22265625 -16.216796875\n",
      "1699 7500 3.641843795776367e-05 4.465721080341547e-05 27.015625 -17.62109375\n",
      "1799 7500 0.0001685619354248047 4.40610627752862e-05 25.3984375 -14.48828125\n",
      "1899 7500 0.01073455810546875 4.343782793395435e-05 22.115234375 -14.630859375\n",
      "1999 7500 5.996227264404297e-05 4.267766952966369e-05 22.9453125 -19.3671875\n"
     ]
    }
   ],
   "source": [
    "for i, data in enumerate(loader):\n",
    "    with accelerator.accumulate(model_critic):\n",
    "        loss, value_chosen_sum, value_rejected_sum = model_critic(**data)\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        if accelerator.sync_gradients:\n",
    "            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    if (i + 1) % 100 == 0:\n",
    "        lr = optimizer.param_groups[0]['lr']\n",
    "        print(i, len(loader), loss.item(), lr, value_chosen_sum,\n",
    "              value_rejected_sum)\n",
    "\n",
    "    if i == 2000:\n",
    "        break\n",
    "\n",
    "torch.save(model_critic.to('cpu'), 'model/critic')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
